Improve the `logpdf` of `NegativeBinomial` #1583
Codecov Report
Base: 85.60% // Head: 85.94% // Increases project coverage by +0.33%

Additional details and impacted files

@@            Coverage Diff             @@
##           master    #1583      +/-   ##
==========================================
+ Coverage   85.60%   85.94%   +0.33%
==========================================
  Files         127      129       +2
  Lines        8044     8080      +36
==========================================
+ Hits         6886     6944      +58
+ Misses       1158     1136      -22
Wouldn't it make sense to add a test for this now?

It's not fixed by the PR alone, so the test would fail.

Now that JuliaDiff/DiffRules.jl#85 and JuliaStats/LogExpFunctions.jl#57 have been merged, this could be revived, right?
Force-pushed from 1798647 to da74c00.
FYI I applied similar changes to the reverse-mode rule as well, and moved the computation of constants outside of the pullback. That leads to a significant performance improvement of the pullback function (of course, this only matters if it is evaluated multiple times):

On master:

julia> using Distributions, ChainRulesCore, BenchmarkTools
julia> d = NegativeBinomial(0.8, 0.5);
julia> x = 1;
julia> @btime rrule($logpdf, $d, $x);
  45.421 ns (0 allocations: 0 bytes)
julia> _, pb = rrule(logpdf, d, x);
julia> @btime $pb($(randn()));
  36.472 ns (0 allocations: 0 bytes)

With this PR:

julia> using Distributions, ChainRulesCore, BenchmarkTools
julia> d = NegativeBinomial(0.8, 0.5);
julia> x = 1;
julia> @btime rrule($logpdf, $d, $x);
  77.026 ns (0 allocations: 0 bytes)
julia> _, pb = rrule(logpdf, d, x);
julia> @btime $pb($(randn()));
  1.501 ns (0 allocations: 0 bytes)
    r, p = params(d)
    z = xlogy(r, p) + xlog1py(k, -p)
    if iszero(k)
        # in this case `log(k + r) + logbeta(r, k + 1) == 0` analytically but unfortunately not numerically
Maybe add another comment line on why this leads to returning `z`?
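For reference, a short derivation of why the `k == 0` branch can return `z` directly. It assumes the log-density is written as `z - log(k + r) - logbeta(r, k + 1)`; that form is inferred from the code comment, not quoted from the diff.

$$
\log B(r, 1) = \log\frac{\Gamma(r)\,\Gamma(1)}{\Gamma(r + 1)} = \log\frac{\Gamma(r)}{r\,\Gamma(r)} = -\log r
$$

$$
k = 0:\quad \log(k + r) + \log B(r, k + 1) = \log r - \log r = 0
\quad\Longrightarrow\quad \log f(0) = z = r \log p
$$

In floating point, `log(r)` and `logbeta(r, 1)` are each rounded separately, so their sum is generally a small nonzero number rather than exactly `0`. Returning `z` avoids that residual.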
    if iszero(k)
        # in this case `log(k + r) + logbeta(r, k + 1) == 0` analytically but unfortunately not numerically
        return z
    elseif insupport(d, k)
I have a preference for early termination on edge cases and finishing with the standard case (as was done before), so a branch for `!insupport` and the normal computation otherwise.

Two minor comments, shouldn't be a big deal.
I tried to address your comments but I was not completely sure what you had in mind. Can you check if I changed it correctly?

I'd like to get this in, so I'll merge since I think I addressed the two comments. I'm happy to open a follow-up PR if something could/should be improved.
Uses `xlogy` and `xlog1py` now, which take care of the special case `p == 1` automatically. Moreover, for improved numerical accuracy in all cases where `k == 0` (and slightly improved efficiency), there's a new branch for `k == 0`.

Will fix #1582 once the rules for `xlogy` and `xlog1py` in DiffRules and LogExpFunctions are fixed (will open PRs).

Edit: Fixes #1582.
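A minimal sketch of the `p == 1` special case mentioned above, assuming LogExpFunctions is installed. The `naive` expression is an illustrative plain-multiplication formulation, not necessarily the code that was on master.

```julia
using LogExpFunctions: xlogy, xlog1py

# Degenerate case: success probability 1, evaluated at k = 0.
r, p, k = 0.8, 1.0, 0

# Plain multiplication: log1p(-1.0) is -Inf, and 0 * -Inf is NaN.
naive = r * log(p) + k * log1p(-p)

# xlog1py(x, y) returns zero when x is zero (and y is not NaN),
# so the -Inf from log1p(-p) never enters the sum.
safe = xlogy(r, p) + xlog1py(k, -p)

println("naive = ", naive)  # NaN
println("safe  = ", safe)   # 0.0
```

With `p == 1` all probability mass sits at `k == 0`, so a log-density of `0.0` there is the expected value.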